Denoising Diffusion Implicit Models for the Oxford Flowers 102 Dataset¶
Introduction¶
Oxford Flowers 102 Dataset¶
https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
A dataset of 102 flower categories, chosen among flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images. The images have large scale, pose and light variations. In addition, there are categories with large variations within the category, and several categories that are very similar to each other.
There are a total of 8,189 images, and the download is ~330 MB.
More information in the TensorFlow Datasets page:
https://www.tensorflow.org/datasets/catalog/oxford_flowers102
The package github.com/gomlx/gomlx/examples/oxfordflowers102
provides a train.Dataset object that can be used to train models with the dataset. It also provides a simplified
mechanism to download and cache the dataset.
Denoising Diffusion Implicit Models¶
This notebook is an example of a diffusion model, based on the Keras example in:
https://keras.io/examples/generative/ddim/
The modeling and training code is in github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion.
There is a training binary for the command line in the train/ subdirectory; it's a very thin wrapper around the diffusion library.
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gonb" "${HOME}/Projects/gopjrt" "${HOME}/Projects/bsplines"
%goworkfix
- Added replace rule for module "github.com/gomlx/bsplines" to local directory "/home/janpf/Projects/bsplines".
- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".
Hyperparameters and Configuration¶
Some basic parameters are set as flags, but everything else is set as hyperparameters in the context.Context. The hyperparameters can be configured using the --set command-line flag.
Below we define the ContextFromSettings function that we are going to use everywhere, and we print out the available hyperparameters.
See diffusion.CreateDefaultContext for documentation on all hyperparameters.
import (
	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion"
	"github.com/gomlx/gomlx/ml/context"

	// Use XLA backend.
	_ "github.com/gomlx/gomlx/backends/xla"
)
var (
flagDataDir = flag.String("data", "~/work/oxfordflowers102", "Directory to cache downloaded and generated dataset files.")
flagEval = flag.Bool("eval", true, "Whether to evaluate the model on the validation data in the end.")
flagVerbosity = flag.Int("verbosity", 1, "Level of verbosity, the higher the more verbose.")
flagCheckpoint = flag.String("checkpoint", "", "Directory save and load checkpoints from. If left empty, no checkpoints are created.")
// settings is bound to a "-set" flag to be used to set context hyperparameters.
settings = commandline.CreateContextSettingsFlag(diffusion.CreateDefaultContext(), "set")
)
// ContextFromSettings is the default context (diffusion.CreateDefaultContext) changed by the -set flag.
// It requires that flags are already parsed.
//
// It also returns the list of parameters that were set.
func ContextFromSettings() (ctx *context.Context, paramsSet []string) {
ctx = diffusion.CreateDefaultContext()
paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))
return
}
// ConfigFromFlags returns a diffusion.Config object initialized from the settings.
// It requires that flags are already parsed.
func ConfigFromFlags() *diffusion.Config {
backend := backends.New()
ctx, paramsSet := ContextFromSettings()
return diffusion.NewConfig(backend, ctx, *flagDataDir, paramsSet)
}
%% -set="dtype=float32;train_steps=1_000"
c := ConfigFromFlags()
fmt.Println(commandline.SprintContextSettings(c.Context))
Context hyperparameters:
	"activation": (string) swish
	"adam_dtype": (string)
	"adam_epsilon": (float64) 1e-07
	"adam_weight_decay": (float64) 0.0001
	"batch_size": (int) 64
	"checkpoint_frequency": (string) 3m
	"cosine_schedule_steps": (int) 0
	"diffusion_channels_list": ([]int) [32 64 96 128]
	"diffusion_context_features": (bool) false
	"diffusion_loss": (string) mae
	"diffusion_max_signal_ratio": (float64) 0.95
	"diffusion_min_signal_ratio": (float64) 0.02
	"diffusion_num_residual_blocks": (int) 2
	"dropout_rate": (float64) 0.15
	"dtype": (string) float32
	"eval_batch_size": (int) 128
	"flower_type_embed_size": (int) 16
	"huber_delta": (float64) 0.2
	"image_size": (int) 64
	"kid": (bool) false
	"l1_regularization": (float64) 0
	"l2_regularization": (float64) 0
	"learning_rate": (float64) 0.001
	"model": (string) bow
	"nan_logger": (bool) false
	"normalization": (string) layer
	"num_checkpoints": (int) 5
	"optimizer": (string) adam
	"plots": (bool) true
	"rng_reset": (bool) true
	"samples_during_training": (int) 64
	"samples_during_training_frequency": (int) 200
	"samples_during_training_frequency_growth": (float64) 1.2
	"sinusoidal_embed_size": (int) 32
	"sinusoidal_max_freq": (float64) 1000
	"sinusoidal_min_freq": (float64) 1
	"train_steps": (int) 1000
import (
"flag"
flowers "github.com/gomlx/gomlx/examples/oxfordflowers102"
"github.com/janpfeifer/must"
)
%%
c := ConfigFromFlags()
must.M(flowers.DownloadAndParse(c.DataDir))
fmt.Println("Oxford Flowers 102 dataset downloaded.")
Oxford Flowers 102 dataset downloaded.
Sample of Flowers¶
To do that we create a temporary dataset (with NewDataset) of 256x256-pixel images, and then show a sample of the flowers.
Later we will use a model that uses only 64x64 pixels.
import (
timage "github.com/gomlx/gomlx/types/tensors/images"
)
// sampleTable generates and outputs one html table of samples, sampling rows x cols from the images/labels provided.
func sampleTable(title string, ds train.Dataset, rows, cols int) {
htmlRows := make([]string, 0, rows)
for row := 0; row < rows; row++ {
cells := make([]string, 0, cols)
for col := 0; col < cols; col++ {
cells = append(cells, sampleOneImage(ds))
}
htmlRows = append(htmlRows, fmt.Sprintf("<tr>\n\t<td>%s</td>\n</tr>", strings.Join(cells, "</td>\n\t<td>")))
}
htmlTable := fmt.Sprintf("<h4>%s</h4><table>%s</table>\n", title, strings.Join(htmlRows, ""))
gonbui.DisplayHTML(htmlTable)
}
// sampleOneImage samples one image from the dataset and returns an HTML-rendered image with its label.
func sampleOneImage(ds train.Dataset) string {
_, inputs, labels := must.M3(ds.Yield())
imgTensor := inputs[0]
img := timage.ToImage().Single(imgTensor)
exampleNum := inputs[1].Value().(int64)
label := labels[0].Value().(int32)
labelStr := flowers.Names[label]
imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))
size := imgTensor.Shape().Dimensions[0]
return fmt.Sprintf(`<figure style="padding:4px;text-align: center;"><img width="%d" height="%d" src="%s"><figcaption style="text-align: center;">Example %d:<br/><span>%s (%d)</span><br/>(%dx%d pixels)</figcaption></figure>`,
size, size, imgSrc, exampleNum, labelStr, label, img.Bounds().Dx(), img.Bounds().Dy())
}
%% --set="image_size=256"
c := ConfigFromFlags()
must.M(flowers.DownloadAndParse(c.DataDir))
ds := flowers.NewDataset(dtypes.U8, c.ImageSize)
ds.Shuffle()
sampleTable("Oxford 102 Flowers Sample", ds, 4, 6)
Oxford 102 Flowers Sample

[4x6 grid of 256x256 sample images, each captioned with its flower name and label: magnolia (86), foxglove (93), grape hyacinth (24), canna lily (89), globe-flower (15), petunia (50), primula (52), water lily (72), lenten rose (39), hard-leaved pocket orchid (1), prince of wales feathers (26), foxglove (93), pelargonium (54), petunia (50), bougainvillea (94), gaura (56), wallflower (45), frangipani (80), poinsettia (43), globe thistle (9), bishop of llandaff (55), common dandelion (49), spear thistle (13), king protea (12)]
In-Memory Dataset for Fast Access¶
We convert the flowers dataset to an InMemoryDataset, and cache its contents for a faster start-up time.
The first time it runs, it reads and converts all images to the target size; it then saves a cache of the generated content, so subsequent runs are faster.
From a local benchmark (go test -bench=. -test.run=Benchmark, with --batch=64):
- Directly reading (and parsing) from disk: ~215 ms/batch.
- Parallelized (24 cores) reading from disk: ~25 ms/batch.
- InMemory batches in GPU: ~41 µs/batch.
// Remove cached files to force regeneration.
!rm -f "${HOME}/work/oxfordflowers102/"*_cached_images_*
%%
c := ConfigFromFlags()
trainDS, validationDS := c.CreateInMemoryDatasets()
fmt.Println()
fmt.Printf("Total number of examples: #train=%d, #validation=%d\n", trainDS.NumExamples(), validationDS.NumExamples())
fmt.Printf("trainDS (in-memory) using %s of memory.\n", data.ByteCountIEC(trainDS.Memory()))
fmt.Printf("validationDS (in-memory) using %s of memory.\n", data.ByteCountIEC(validationDS.Memory()))
// Output a random sample.
trainDS.Shuffle()
sampleTable("Oxford 102 Flowers Sample -- In-Memory Dataset", trainDS, 1, 6)
Creating InMemoryDataset for "train" with images cropped and scaled to 64x64...
	- 3.446209827s to process dataset.
Creating InMemoryDataset for "validation" with images cropped and scaled to 64x64...
	- 1.035522949s to process dataset.
Total number of examples: #train=6487, #validation=1702
trainDS (in-memory) using 76.1 MiB of memory.
validationDS (in-memory) using 20.0 MiB of memory.
Oxford 102 Flowers Sample -- In-Memory Dataset

[1x6 row of 64x64 sample images: gazania (70), common dandelion (49), wallflower (45), petunia (50), pink-yellow dahlia? (59), cyclamen (87)]
Denoising Diffusion Implicit Model¶
Preprocessing of images¶
The diffusion model takes images normalized to a mean of 0 and a standard deviation of 1, and generates images in the same range.
The methods PreprocessImages and DenormalizeImages convert the images to floats and normalize/denormalize them.
Below we quickly test that calling PreprocessImages and then DenormalizeImages has no effect on a random batch of images.
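The round trip can also be illustrated in plain Go on single channel values -- a stdlib-only sketch, not the actual GoMLX implementation; the per-channel mean and stddev below are placeholder values (the real ones come from the dataset statistics):

```go
package main

import "fmt"

// normalize maps a raw [0, 255] channel value to roughly zero mean, unit stddev.
func normalize(x, mean, stddev float32) float32 { return (x - mean) / stddev }

// denormalize is the exact inverse of normalize.
func denormalize(x, mean, stddev float32) float32 { return x*stddev + mean }

func main() {
	// Placeholder per-channel statistics, for illustration only.
	mean, stddev := float32(121.0), float32(75.6)
	for _, raw := range []float32{0, 100, 255} {
		back := denormalize(normalize(raw, mean, stddev), mean, stddev)
		fmt.Printf("raw=%.1f -> normalized=%+.3f -> back=%.1f\n",
			raw, normalize(raw, mean, stddev), back)
	}
}
```

Since denormalize is the algebraic inverse of normalize, the round trip should leave values unchanged up to float32 rounding.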
import (
. "github.com/gomlx/gomlx/graph"
"github.com/gomlx/gomlx/ml/context"
"github.com/gomlx/gomlx/ml/data"
"github.com/janpfeifer/gonb/gonbui"
)
var _ = NewGraph // Avoid warnings of non-used import.
%% --set="image_size=128"
c := ConfigFromFlags()
mean, stddev := c.NormalizationValues()
fmt.Printf("Flower images, per channel (red, green, blue):\n\t mean=%v\n\tstddev=%v\n", mean.Value(), stddev.Value())
trainDS, _ := c.CreateInMemoryDatasets()
trainDS.Shuffle()
trainDS.BatchSize(6, true)
_, inputs, _ := must.M3(trainDS.Yield())
gonbui.DisplayHTML("<p><b>Original:</b></p>")
diffusion.PlotImagesTensor(inputs[0])
e := NewExec(c.Backend, func(images *Node) *Node {
images = c.PreprocessImages(images, true)
images = c.DenormalizeImages(images)
return images
})
gonbui.DisplayHTML("<p><b>After normalization and denormalization:</b></p>")
imagesT := e.Call(inputs[0])[0]
fmt.Printf("imagesT.shape=%s\n", imagesT.Shape())
diffusion.PlotImagesTensor(imagesT)
Flower images, per channel (red, green, blue):
	 mean=[[[[121.027176 100.22015 78.19373]]]]
	stddev=[[[[75.55288 62.183628 69.93133]]]]
Creating InMemoryDataset for "train" with images cropped and scaled to 128x128...
	- 5.695557788s to process dataset.
Creating InMemoryDataset for "validation" with images cropped and scaled to 128x128...
	- 1.794671756s to process dataset.
Original:
After normalization and denormalization:
imagesT.shape=(Float32)[6 128 128 3]
Sinusoidal Embedding¶
Used to embed the variance of the noise at different frequencies.
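For reference, here is a stdlib-only sketch of what such an embedding computes, following the Keras DDIM example's scheme: embedSize/2 frequencies geometrically spaced between sinusoidal_min_freq and sinusoidal_max_freq, each contributing a sin and a cos component. The actual diffusion.SinusoidalEmbedding may differ in details such as component ordering:

```go
package main

import (
	"fmt"
	"math"
)

// sinusoidalEmbedding embeds a scalar x at embedSize/2 geometrically spaced
// frequencies in [minFreq, maxFreq], interleaving sin and cos components.
// This is an illustration of the idea, not the GoMLX implementation.
func sinusoidalEmbedding(x float64, embedSize int, minFreq, maxFreq float64) []float64 {
	half := embedSize / 2
	out := make([]float64, 0, embedSize)
	for i := 0; i < half; i++ {
		// Geometric interpolation from minFreq to maxFreq.
		t := float64(i) / float64(half-1)
		freq := minFreq * math.Pow(maxFreq/minFreq, t)
		angle := 2 * math.Pi * freq * x
		out = append(out, math.Sin(angle), math.Cos(angle))
	}
	return out
}

func main() {
	emb := sinusoidalEmbedding(1.0, 32, 1.0, 1000.0)
	fmt.Printf("len=%d, first pair=(%.4f, %.4f)\n", len(emb), emb[0], emb[1])
}
```

The defaults above (sinusoidal_embed_size=32, sinusoidal_min_freq=1, sinusoidal_max_freq=1000) give a 32-dimensional embedding.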
import (
"github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion"
. "github.com/gomlx/gomlx/graph"
)
%%
c := ConfigFromFlags()
value := NewExec(c.Backend, func (x *Node) *Node {
return diffusion.SinusoidalEmbedding(c.Context, x)
}).Call(float32(1.0))[0]
fmt.Printf("SinusoidalEmbedding(1.0)=\n\tShape: %s\n\tValue: %v\n", value.Shape(), value.Value())
SinusoidalEmbedding(1.0)=
	Shape: (Float32)[32]
	Value: [1.7484555e-07 -0.5084644 -0.074616365 -0.11864995 0.93075866 2.70213e-06 -0.8129459 0.6793376 -0.92810476 0.5659511 1.176251e-05 0.06701087 0.9265504 0.62322754 -0.26391345 0.0007279766 1 -0.86108303 -0.9972123 0.99293613 -0.3656342 1 0.5823393 0.73382586 0.37231916 0.8244388 1 -0.99775225 0.37617072 0.7820406 0.9645464 0.99999976]
U-Net Model¶
The code in diffusion.UNetModelGraph follows the Keras example's Network Architecture.
The modeling functions are:

- UNetModelGraph builds the computation graph that maps a noisy image to (predicted image, predicted noise), using the U-Net model. It's the core of this example.
- BuildTrainingModelGraph builds the train.ModelFn (the function that GoMLX uses for a training loop). It takes raw images as examples, adds some random noise at a random time (from 0.0 to 1.0), and uses the U-Net model to try to separate the noise. It returns the predicted image and the loss, where the loss is measured on the predicted noise -- this works better for learning than predicting the original image (*).
(*) My hypothesis is that predicting the original image is easier to overfit: we have a limited number of images, but can generate unlimited noise.
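To make the noise-mixing step concrete, here is a stdlib-only sketch of the cosine diffusion schedule used in the Keras DDIM example, parameterized by the diffusion_min_signal_ratio and diffusion_max_signal_ratio hyperparameters above -- assuming the GoMLX version follows the same scheme:

```go
package main

import (
	"fmt"
	"math"
)

// diffusionSchedule returns the signal and noise mixing ratios for a diffusion
// time t in [0, 1]. It interpolates the angle between acos(maxSignal) and
// acos(minSignal), so the ratios always satisfy signal² + noise² = 1.
func diffusionSchedule(t, minSignal, maxSignal float64) (signalRatio, noiseRatio float64) {
	startAngle := math.Acos(maxSignal) // t=0: mostly signal.
	endAngle := math.Acos(minSignal)   // t=1: mostly noise.
	angle := startAngle + t*(endAngle-startAngle)
	return math.Cos(angle), math.Sin(angle)
}

func main() {
	for _, t := range []float64{0, 0.5, 1} {
		s, n := diffusionSchedule(t, 0.02, 0.95)
		// A noisy training example would then be: noisy = s*image + n*noise.
		fmt.Printf("t=%.1f signal=%.3f noise=%.3f s²+n²=%.3f\n", t, s, n, s*s+n*n)
	}
}
```

With the defaults, a training example at a random time t is mixed as noisy = signalRatio*image + noiseRatio*noise, and the U-Net is asked to recover the noise component.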
The model in its default configuration uses ~3.5 million parameters:
import (
. "github.com/gomlx/gomlx/graph"
"github.com/gomlx/gomlx/ml/context"
"github.com/gomlx/gomlx/types/shapes"
"github.com/gomlx/gopjrt/dtypes"
)
// batch_size=5 just for testing.
%% --set="flower_type_embed_size=16;batch_size=5"
c := ConfigFromFlags()
fmt.Printf("Backend %q: %s\n", c.Backend.Name(), c.Backend.Description())
fmt.Println("\nUNetModelGraph:")
g := NewGraph(c.Backend, "test")
noisyImages := Zeros(g, shapes.Make(c.DType, c.BatchSize, 64, 64, 3))
flowerIds := Zeros(g, shapes.Make(dtypes.Int32, c.BatchSize))
fmt.Printf(" noisyImages.shape:\t%s\n", noisyImages.Shape())
filtered := diffusion.UNetModelGraph(c.Context, noisyImages, Ones(g, shapes.Make(c.DType, 5, 1, 1, 1)), flowerIds)
fmt.Printf(" filtered.shape:\t%s\n", filtered.Shape())
fmt.Printf("U-Net Model #params:\t%d\n", c.Context.NumParameters())
fmt.Printf(" U-Net Model memory:\t%s\n", data.ByteCountIEC(c.Context.Memory()))
fmt.Println("\nModelGraph:")
images := Zeros(g, shapes.Make(c.DType, 5, c.ImageSize, c.ImageSize, 3))
fmt.Printf(" images.shape:\t%s\n", images.Shape())
modelFn := c.BuildTrainingModelGraph()
predictions := modelFn(c.Context.Reuse(), nil, []*Node{images, nil, flowerIds})
fmt.Printf("predictedImages.shape:\t%s\n", predictions[0].Shape())
fmt.Printf(" loss.shape:\t%s\n", predictions[1].Shape())
fmt.Printf(" Model #params:\t%d\n", c.Context.NumParameters())
fmt.Printf(" Model memory:\t%s\n", data.ByteCountIEC(c.Context.Memory()))
Backend "xla": xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54
UNetModelGraph:
noisyImages.shape: (Float32)[5 64 64 3]
filtered.shape: (Float32)[5 64 64 3]
U-Net Model #params: 3549734
U-Net Model memory: 13.5 MiB
ModelGraph:
images.shape: (Float32)[5 64 64 3]
predictedImages.shape: (Float32)[5 64 64 3]
loss.shape: (Float32)
Model #params: 3549734
Model memory: 13.5 MiB
Training Model¶
The training was mostly done from the command line -- it's easier to leave it running for hours -- using the train program.
It can be installed with go install github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion/train@latest, but usually
I would just go to the directory and run go run . <...flags...>; see some examples below.
It also requires github.com/gomlx/gopjrt installed with the PJRT plugins
for your device (CPU, GPU, etc.).
The train program (and the underlying library function) saves evaluation points as it trains, and these can be plotted; see below for an example.
If the training is interrupted and restarted, it continues from where it left off.
See also github.com/gomlx/gomlx/cmd/gomlx_checkpoints to pretty-print the metrics, a model summary, its variables and its hyperparameters from the command line.
During the writing of this notebook, a few such models were generated, with different hyperparameters. But the space was not explored much -- if someone with more GPU time available is willing to try some hyperparameter tuning, or has a better diffusion model to use, please share!
Training from the notebook¶
Because it takes many hours, we recommend training from the command line. But it can be done from the notebook as well -- it uses the same training function -- which can be useful for testing and development.
!rm -rf ~/work/oxfordflowers102/test/
%% --checkpoint "test" --set="train_steps=2000;plots=true;diffusion_num_residual_blocks=2"
c := ConfigFromFlags()
fmt.Printf("Backend %q: %s\n", c.Backend.Name(), c.Backend.Description())
diffusion.TrainModel(c.Context, *flagDataDir, *flagCheckpoint, c.ParamsSet, *flagEval, *flagVerbosity)
Backend "xla": xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54
Checkpoint: "/home/janpf/work/oxfordflowers102/test"
	train_steps=2000
	plots=true
	diffusion_num_residual_blocks=2
Loss: mae
Learning rate: 0.001000
Training (2000 steps): 22% [=======>................................] (7 steps/s) [1m18s:3m30s] [step=439] [loss+=0.230] [~loss+=0.365] [~loss=0.365] [~img_loss=0.581]
Training (2000 steps): 100% [========================================] (6 steps/s) [step=1999] [loss+=0.189] [~loss+=0.197] [~loss=0.197] [~img_loss=0.208]
Metric: img_loss
Metric: loss
[Step 2000] median train step: 136505 microseconds
Results on train:
	Mean Loss+Regularization (#loss+): 0.195
	Mean Loss (#loss): 0.195
	Images Loss (img_loss): 0.203
Results on validation:
	Mean Loss+Regularization (#loss+): 0.195
	Mean Loss (#loss): 0.195
	Images Loss (img_loss): 0.208
Generating images from a model¶
To generate images, we build and execute the model over several diffusion steps, each assuming less noise and more signal, starting from purely random noise.
The function GenerateImages(numImages, numSteps, displayEveryNSteps) orchestrates this for us.
This only works with a trained model saved to a checkpoint.
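The reverse diffusion loop behind this can be sketched in plain Go on a single scalar value. The denoise stub below stands in for the trained U-Net, and the schedule hard-codes the default signal-ratio bounds -- this only illustrates the loop structure, not the diffusion package's implementation:

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// schedule returns signal/noise mixing ratios for time t in [0, 1], using the
// cosine schedule with the default signal-ratio bounds (0.02 and 0.95).
func schedule(t float64) (signal, noise float64) {
	a0, a1 := math.Acos(0.95), math.Acos(0.02)
	a := a0 + t*(a1-a0)
	return math.Cos(a), math.Sin(a)
}

// denoise is a stub standing in for the trained U-Net: given a noisy value and
// the current noise ratio, it should predict the noise component. This naive
// version attributes the whole input to noise, only so the sketch runs.
func denoise(noisy, noiseRatio float64) (predNoise float64) {
	return noisy
}

// generate runs the reverse diffusion loop on one scalar "pixel", starting from
// pure noise x0: at each step it separates the current estimate into predicted
// image and predicted noise, then re-mixes them at the next, less noisy, step.
func generate(x0 float64, numSteps int) float64 {
	x, predImage := x0, 0.0
	for step := numSteps; step > 0; step-- {
		signal, noise := schedule(float64(step) / float64(numSteps))
		predNoise := denoise(x, noise)
		predImage = (x - noise*predNoise) / signal // Signal estimate.
		sNext, nNext := schedule(float64(step-1) / float64(numSteps))
		x = sNext*predImage + nNext*predNoise
	}
	return predImage
}

func main() {
	fmt.Printf("generated value: %.4f\n", generate(rand.NormFloat64(), 20))
}
```

With a real denoiser in place of the stub, the same loop run over an image tensor is what produces the samples shown below.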
%% --checkpoint="v5_64x64"
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
c.DisplayImagesAcrossDiffusionSteps(12, 20, 10)
Checkpoint: "/home/janpf/work/oxfordflowers102/v5_64x64"
DisplayImagesAcrossDiffusionSteps(12 images, 20 steps): noise.shape=(Float32)[12 64 64 3]
Model #params: 5496166
Model memory: 21.0 MiB
Noise
5.0% Denoised -- Step 1/20
55.0% Denoised -- Step 11/20
100.0% Denoised -- Step 20/20
Generating 102 Images For Each Flower Type, Same Noise¶
Notice that our current model is not very good yet: it generates more or less the same image for every flower type.
%% --checkpoint="v5_64x64"
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
diffusion.PlotImagesTensor(c.GenerateImagesOfAllFlowerTypes(20))
Checkpoint: "/home/janpf/work/oxfordflowers102/v5_64x64"
Fixed starting noise, generated flower at different timesteps¶
At the start of each new model, the diffusion.TrainModel() function saves a set of random noise tensors, and every time the model is evaluated, images are generated from that fixed noise at the corresponding checkpoint. This way one can see how the model evolves during training.
Below are the images generated by our v5_64x64 model at various steps during training:
%% --checkpoint="v5_64x64"
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
c.PlotModelEvolution(10, false)
Checkpoint: "/home/janpf/work/oxfordflowers102/v5_64x64"
Generated samples in /home/janpf/work/oxfordflowers102/v5_64x64:
- global_step 200:
- global_step 440:
- global_step 728:
- global_step 1074:
- global_step 1489:
- global_step 1987:
- global_step 2585:
- global_step 3303:
- global_step 4165:
- global_step 5199:
- global_step 6440:
- global_step 7929:
- global_step 9716:
- global_step 11860:
- global_step 14433:
- global_step 17521:
- global_step 21227:
- global_step 25674:
- global_step 31010:
- global_step 37413:
- global_step 45097:
Older Version¶
Below are results using an older version of GoMLX: the models trained then were better, but I didn't retrain them on the newer version.
TODO: improve them either with more training time, more hyperparameter tuning or a better model.
Larger 128x128 model with Transformer blocks¶
Below are some random images generated by a model trained at 128x128 resolution, with more blocks, and with 4 attention layers in the middle.
I also tried the mean squared error (MSE) loss function.
%% --checkpoint="model_128x128_01" --size=128 --att_layers=4 --blocks=6 --norm=layer --activation=sigmoid --channels_list=16,32,64,96,128 --loss=mse --checkpoint_mean=-1
diffusion.PlotImagesTensor(diffusion.GenerateImages(90, 20, 0))
Model conditioned on flower types¶
The flag --flower_type_dim=16 makes the model use the flower type (flowerIds) as a feature, embedding it with the given dimension at the start of each block.
We trained the model model_64x64_02 for 200K steps using this flag:
$ go run . --steps=200000 --plots --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16
Below is the generation output for a few random examples of a few random flower types:
%% --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16 --checkpoint_mean=-1
for ii := 0; ii < 5; ii++ {
flowerType := int32(rand.Intn(flowers.NumLabels))
gonbui.DisplayHTML(fmt.Sprintf("<p>Generated <b>%s</b></p>\n", flowers.Names[flowerType]))
diffusion.PlotImagesTensor(diffusion.GenerateImagesOfFlowerType(18, flowerType, 30))
}
Generated gazania
Generated lotus
Generated alpine sea holly
Generated wallflower
Generated thorn apple
One starting noise, different flower types¶
%% --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16 --checkpoint_mean=-1
diffusion.PlotImagesTensor(diffusion.GenerateImagesOfAllFlowerTypes(20))
%% --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16 --checkpoint_mean=-1
// Load model
ctx := context.NewContext(manager).Checked(false)
_, _, _ = diffusion.LoadCheckpointToContext(ctx)
ctx.RngStateReset()
// Create UI with diffusion generated flowers.
divId := dom.CreateTransientDiv()
// cache.ResetKey("slider_diffusion_steps")
doneSteps := diffusion.SliderDiffusionSteps("slider_diffusion_steps", ctx, 8, 30, divId)
// cache.ResetKey("dropdown_flower_types")
doneFlowers := diffusion.DropdownFlowerTypes("dropdown_flower_types", ctx, 8, 20, divId)
// Wait for OK button.
button := widgets.Button("Ok").AppendTo(divId).Done()
<-button.Listen().C
// Clean up and persist HTML (so it can be saved).
doneSteps.Trigger()
doneFlowers.Trigger()
dom.Persist(divId)
%% --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16 --checkpoint_mean=-1
// Load model
ctx := context.NewContext(manager).Checked(false)
_, _, _ = diffusion.LoadCheckpointToContext(ctx)
// Create UI with diffusion generated flowers.
divId := dom.CreateTransientDiv()
diffusion.SliderDiffusionSteps("slider_diffusion_steps", ctx, 8, 30, divId)
diffusion.DropdownFlowerTypes("dropdown_flower_types", ctx, 8, 20, divId)
// Wait for OK button.
button := widgets.Button("Ok").AppendTo(divId).Done()
<-button.Listen().C